{
 "cells": [
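  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Maximize the entropy of the action distribution to make training more stable.\n",
    "\n",
    "Q(state, action) + alpha * entropy\\[pi(* | state)\\]\n",
    "\n",
    "alpha should decay over the course of training.\n",
    "\n",
    "To mitigate the overestimation caused by bootstrapping, two value models estimate the Q-function and the smaller of the two estimates is used."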
   ]
  },
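  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked illustration of the entropy term, here is a sketch assuming the two-action CartPole setting used in this notebook. The symbols $\\pi$ (the policy, i.e. the action probabilities), $H$ (entropy), $s$ (state) and $a$ (action) are notation introduced here for illustration.\n",
    "\n",
    "$$H[\\pi(\\cdot \\mid s)] = -\\sum_{a} \\pi(a \\mid s) \\log \\pi(a \\mid s)$$\n",
    "\n",
    "With two actions, the entropy is largest for the uniform policy $(0.5, 0.5)$, where $H = \\ln 2 \\approx 0.693$. A confident policy such as $(0.9, 0.1)$ gives $H = -(0.9 \\ln 0.9 + 0.1 \\ln 0.1) \\approx 0.325$, and a deterministic policy gives $H = 0$. Adding alpha times this entropy to the value therefore rewards policies that keep exploring, and decaying alpha gradually removes that reward."
   ]
  },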
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAT2ElEQVR4nO3dbWxT99kG8MtOYkNejtOQxm5ELDKVjka8tQHCaT+saj0yGlVjTaW1Ql1aISqYg0ozoWeRWqp2m1IxaXTtKHxZodPUsSeT2NQISqPQBu3BkBKaRyGUPO1Glwiw3ZblOC/ESez7+YByVkOgceL4b+PrJx2Jc/637fsc4yvnxS8WEREQESlgVd0AEWUuBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESmjLID27NmDRYsWYd68eaiqqkJHR4eqVohIESUB9Oc//xkNDQ14+eWXcebMGaxYsQLV1dUIBoMq2iEiRSwqPoxaVVWF1atX43e/+x0AIBqNoqysDNu2bcPPf/7zZLdDRIpkJ/sBx8bG0NnZicbGRnOZ1WqFx+OBz+eb8jbhcBjhcNicj0ajuHLlChYsWACLxTLnPRNRfEQEg4ODKC0thdV68wOtpAfQV199hUgkAqfTGbPc6XTi/PnzU96mqakJr7zySjLaI6IE6u/vx8KFC286nvQAmonGxkY0NDSY84ZhwO12o7+/H5qmKeyMiKYSCoVQVlaGgoKCW9YlPYCKi4uRlZWFQCAQszwQCMDlck15G7vdDrvdfsNyTdMYQEQp7NtOkST9KpjNZkNlZSXa2trMZdFoFG1tbdB1PdntEJFCSg7BGhoaUFdXh1WrVmHNmjV4/fXXMTw8jGeffVZFO0SkiJIA+vGPf4wvv/wSO3fuhN/vx8qVK/H+++/fcGKaiG5vSt4HNFuhUAgOhwOGYfAcEFEKmu5rlJ8FIyJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTJxB9Dx48fx2GOPobS0FBaLBX/9619jxkUEO3fuxF133YX58+fD4/Hgs88+i6m5cuUKNm7cCE3TUFhYiE2bNmFoaGhWK0JE6SfuABoeHsaKFSuwZ8+eKcd37dqFN954A/v27cOpU6eQl5eH6upqjI6OmjUbN25ET08PWltb0dLSguPHj+O5556b+VoQUXqSWQAghw4dMuej0ai4XC759a9/bS4bGBgQu90uf/rTn0RE5Ny5cwJAPv74Y7PmyJEjYrFY5OLFi9N6XMMwBIAYhjGb9olojkz3NZrQc0AXLlyA3++Hx+MxlzkcDlRVVcHn8wEAfD4fCgsLsWrVKrPG4/HAarXi1KlTU95vOBxGKBSKmYgo/SU0gPx+PwDA6XTGLHc6neaY3+9HSUlJzHh2djaKiorMmus1NTXB4XCYU1lZWSLbJiJF0uIqWGNjIwzDMKf+/n7VLRFRAiQ0gFwuFwAgEAjELA8EAuaYy+VCMBiMGZ+YmMCVK1fMmuvZ7XZomhYzEVH6S2gAlZeXw+Vyoa2tzVwWCoVw6tQp6LoOANB1HQMDA+js7DRrjh07hmg0iqqqqkS2Q0QpLjveGwwNDeHzzz835y9cuICuri4UFRXB7XZj+/bt+OUvf4nFixejvLwcL730EkpLS7FhwwYAwL333osf/OAH2Lx5M/bt24fx8XHU19fjySefRGlpacJWjIjSQLyX1z788EMBcMNUV1cnItcuxb/00kvidDrFbrfLI488Ir29vTH38fXXX8tTTz0l+fn5ommaPPvsszI4OJjwS3xEpMZ0X6MWERGF+TcjoVAIDocDhmHwfBBRCpruazQtroIR0e2JAURE
yjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpE/fP8hAlmkQjGOg7i+j4qLksz/kdzNPuVNgVJQMDiJSLRsbR7/tvhENfmsvKH3qGAZQBeAhGyokI0vDXoSgBGECknkQBBlBGYgCRctf2fhhAmYgBROpJlIdgGYoBRMpJlHtAmYoBROrxHFDGiiuAmpqasHr1ahQUFKCkpAQbNmxAb29vTM3o6Ci8Xi8WLFiA/Px81NbWIhAIxNT09fWhpqYGubm5KCkpwY4dOzAxMTH7taG0JAygjBVXALW3t8Pr9eLkyZNobW3F+Pg41q1bh+HhYbPmhRdewHvvvYfm5ma0t7fj0qVLePzxx83xSCSCmpoajI2N4cSJE3jnnXdw4MAB7Ny5M3FrRelFBMJDsMwksxAMBgWAtLe3i4jIwMCA5OTkSHNzs1nz6aefCgDx+XwiInL48GGxWq3i9/vNmr1794qmaRIOh6f1uIZhCAAxDGM27VOKuDoQkDMHGqRj32Zz+vL8/6hui2Zhuq/RWZ0DMgwDAFBUVAQA6OzsxPj4ODwej1mzZMkSuN1u+Hw+AIDP58OyZcvgdDrNmurqaoRCIfT09Ez5OOFwGKFQKGai2wgvw2esGQdQNBrF9u3b8eCDD2Lp0qUAAL/fD5vNhsLCwphap9MJv99v1nwzfCbHJ8em0tTUBIfDYU5lZWUzbZtSkEQj184DfZPFoqYZSqoZB5DX68XZs2dx8ODBRPYzpcbGRhiGYU79/f1z/piUPGPD/0YkfNWct2bbYC8oVtgRJcuMPoxaX1+PlpYWHD9+HAsXLjSXu1wujI2NYWBgIGYvKBAIwOVymTUdHR0x9zd5lWyy5np2ux12u30mrVIauLb3841DMIsFlqwcZf1Q8sS1ByQiqK+vx6FDh3Ds2DGUl5fHjFdWViInJwdtbW3mst7eXvT19UHXdQCAruvo7u5GMBg0a1pbW6FpGioqKmazLnTbsMDCQ7CMENcekNfrxbvvvou//e1vKCgoMM/ZOBwOzJ8/Hw6HA5s2bUJDQwOKioqgaRq2bdsGXdexdu1aAMC6detQUVGBp59+Grt27YLf78eLL74Ir9fLvRwCcO30j8XK98hmgrgCaO/evQCAhx56KGb5/v378cwzzwAAdu/eDavVitraWoTDYVRXV+Ott94ya7OystDS0oKtW7dC13Xk5eWhrq4Or7766uzWhG4jFsDCAMoEFpH0ewtqKBSCw+GAYRjQNE11OzRLA33d+OzIm+Z8li0X9274L8y/4y6FXdFsTPc1yj8zlHosgIV7QBmBzzKlIAvAc0AZgc8ypRwLuAeUKfgsU+qxWBhAGYLPMqUmHoJlBD7LlHosfCNipmAAkXpTvhOEAZQJGECknESj315EtyUGECknElHdAinCACLluAeUuRhApJxEuQeUqRhApNwN34ZIGYMBROpxDyhjMYBIOe4BZS4GECnHc0CZiwFEyvEqWOZiAJFyPATLXAwgUk4iEzHzFouVvwuWIRhApNzVf1+MmbcXFMPKn+XJCAwgUk4isSehLdk5/DR8hmAAUcq59mVkDKBMwACilHPtHJDqLigZGECUeqzcA8oUDCBKOfw+6MwR1y+jEs1EOBzG1atXbzo+Pj4eOz8RgWEYsFizpqzPzc2FzWZLaI+kBgOI5twf/vAHvPLKKzcd3/HEfdCX/OdXUI8ePYrdW3YhepMf7X399dfxxBNPJLxPSj4GEM25oaEhXLx48abjwyPL8PnV+3FlvBR35vwLoaF/4uLFizcNoJGRkblqlZIsroPtvXv3Yvny5dA0DZqmQdd1HDlyxBwfHR2F1+vFggULkJ+fj9raWgQCgZj76OvrQ01NDXJzc1FSUoIdO3ZgYmLi+oeiDPKPqyvx+ci1APq/kSpcGKmAYOrwodtLXAG0cOFCvPba
a+js7MTp06fx8MMP44c//CF6enoAAC+88ALee+89NDc3o729HZcuXcLjjz9u3j4SiaCmpgZjY2M4ceIE3nnnHRw4cAA7d+5M7FpRWhmaKMTkf0WBFaFxDcyfzBBXAD322GN49NFHsXjxYtxzzz341a9+hfz8fJw8eRKGYeD3v/89fvOb3+Dhhx9GZWUl9u/fjxMnTuDkyZMAgA8++ADnzp3DH//4R6xcuRLr16/HL37xC+zZswdjY2NzsoKU+krt/0C2JQxAYLNcxZ3Z/2T+ZIgZnwOKRCJobm7G8PAwdF1HZ2cnxsfH4fF4zJolS5bA7XbD5/Nh7dq18Pl8WLZsGZxOp1lTXV2NrVu3oqenB/fdd19cPZw/fx75+fkzXQVKkusPw6/3yf+2wtH/LwxMlKAoxw//5fO3rL906RLOnTuXyBYpwYaGhqZVF3cAdXd3Q9d1jI6OIj8/H4cOHUJFRQW6urpgs9lQWFgYU+90OuH3+wEAfr8/JnwmxyfHbiYcDiMcDpvzoVAIAGAYBs8fpYFbXYIHgPauLwB8Me37GxkZwcDAwGxaojk2PDw8rbq4A+i73/0uurq6YBgG/vKXv6Curg7t7e1xNxiPpqamKS/jVlVVQdO0OX1smr1Tp04l9P7uvvtuPPDAAwm9T0qsyZ2EbxP3W05tNhvuvvtuVFZWoqmpCStWrMBvf/tbuFwujI2N3fCXKRAIwOVyAQBcLtcNu+OT85M1U2lsbIRhGObU398fb9tElIJm/Z73aDSKcDiMyspK5OTkoK2tzRzr7e1FX18fdF0HAOi6ju7ubgSDQbOmtbUVmqahoqLipo9ht9vNS/+TExGlv7gOwRobG7F+/Xq43W4MDg7i3XffxUcffYSjR4/C4XBg06ZNaGhoQFFRETRNw7Zt26DrOtauXQsAWLduHSoqKvD0009j165d8Pv9ePHFF+H1emG32+dkBYkodcUVQMFgED/5yU9w+fJlOBwOLF++HEePHsX3v/99AMDu3bthtVpRW1uLcDiM6upqvPXWW+bts7Ky0NLSgq1bt0LXdeTl5aGurg6vvvpqYteKUsrkHmyi8HNgtw+LyE3e757CQqEQHA4HDMPg4VgaGBwcTOhVq6KiIuTl5SXs/ijxpvsa5WfBaM4VFBSgoKBAdRuUgvjFK0SkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEiZbNUNzISIAABCoZDiTohoKpOvzcnX6s2kZQB9/fXXAICysjLFnRDRrQwODsLhcNx0PC0DqKioCADQ19d3y5WjWKFQCGVlZejv74emaarbSQvcZjMjIhgcHERpaekt69IygKzWa6euHA4H/1PMgKZp3G5x4jaL33R2DngSmoiUYQARkTJpGUB2ux0vv/wy7Ha76lbSCrdb/LjN5pZFvu06GRHRHEnLPSAiuj0wgIhIGQYQESnDACIiZdIygPbs2YNFixZh3rx5qKqqQkdHh+qWlGlqasLq1atRUFCAkpISbNiwAb29vTE1o6Oj8Hq9WLBgAfLz81FbW4tAIBBT09fXh5qaGuTm5qKkpAQ7duzAxMREMldFmddeew0WiwXbt283l3GbJYmkmYMHD4rNZpO3335benp6ZPPmzVJYWCiBQEB1a0pUV1fL/v375ezZs9LV1SWPPvqouN1uGRoaMmu2bNkiZWVl0tbWJqdPn5a1a9fKAw88YI5PTEzI0qVLxePxyCeffCKHDx+W4uJiaWxsVLFKSdXR0SGLFi2S5cuXy/PPP28u5zZLjrQLoDVr1ojX6zXnI5GIlJaWSlNTk8KuUkcwGBQA0t7eLiIiAwMDkpOTI83NzWbNp59+KgDE5/OJiMjhw4fFarWK
3+83a/bu3Suapkk4HE7uCiTR4OCgLF68WFpbW+V73/ueGUDcZsmTVodgY2Nj6OzshMfjMZdZrVZ4PB74fD6FnaUOwzAA/OcDu52dnRgfH4/ZZkuWLIHb7Ta3mc/nw7Jly+B0Os2a6upqhEIh9PT0JLH75PJ6vaipqYnZNgC3WTKl1YdRv/rqK0QikZgnHQCcTifOnz+vqKvUEY1GsX37djz44INYunQpAMDv98Nms6GwsDCm1ul0wu/3mzVTbdPJsdvRwYMHcebMGXz88cc3jHGbJU9aBRDdmtfrxdmzZ/H3v/9ddSsprb+/H88//zxaW1sxb9481e1ktLQ6BCsuLkZWVtYNVyMCgQBcLpeirlJDfX09Wlpa8OGHH2LhwoXmcpfLhbGxMQwMDMTUf3ObuVyuKbfp5NjtprOzE8FgEPfffz+ys7ORnZ2N9vZ2vPHGG8jOzobT6eQ2S5K0CiCbzYbKykq0tbWZy6LRKNra2qDrusLO1BER1NfX49ChQzh27BjKy8tjxisrK5GTkxOzzXp7e9HX12duM13X0d3djWAwaNa0trZC0zRUVFQkZ0WS6JFHHkF3dze6urrMadWqVdi4caP5b26zJFF9FjxeBw8eFLvdLgcOHJBz587Jc889J4WFhTFXIzLJ1q1bxeFwyEcffSSXL182p5GREbNmy5Yt4na75dixY3L69GnRdV10XTfHJy8pr1u3Trq6uuT999+XO++8M6MuKX/zKpgIt1mypF0AiYi8+eab4na7xWazyZo1a+TkyZOqW1IGwJTT/v37zZqrV6/KT3/6U7njjjskNzdXfvSjH8nly5dj7ueLL76Q9evXy/z586W4uFh+9rOfyfj4eJLXRp3rA4jbLDn4dRxEpExanQMiotsLA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlLm/wFabp8waHiGyQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Define the environment\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        over = terminated or truncated\n",
    "\n",
    "        # Cap each episode at 200 steps\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    # Render the current frame\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.4089, 0.5911],\n",
       "        [0.4615, 0.5385]], grad_fn=<SoftmaxBackward0>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "model_action = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    "    torch.nn.Softmax(dim=1),\n",
    ")\n",
    "\n",
    "model_action(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1314, -0.0841],\n",
       "        [-0.0564, -0.0505]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_value1 = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    ")\n",
    "\n",
    "model_value2 = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    ")\n",
    "\n",
    "model_value1_next = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    ")\n",
    "\n",
    "model_value2_next = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    ")\n",
    "\n",
    "model_value1_next.load_state_dict(model_value1.state_dict())\n",
    "model_value2_next.load_state_dict(model_value2.state_dict())\n",
    "\n",
    "model_value1(torch.randn(2, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15.0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "# Play one episode and record the transitions\n",
    "def play(show=False):\n",
    "    data = []\n",
    "    reward_sum = 0\n",
    "\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    while not over:\n",
    "        prob = model_action(torch.FloatTensor(state).reshape(1, 4))[0].tolist()\n",
    "        action = random.choices(range(2), weights=prob, k=1)[0]\n",
    "\n",
    "        next_state, reward, over = env.step(action)\n",
    "\n",
    "        data.append((state, action, reward, next_state, over))\n",
    "        reward_sum += reward\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    return data, reward_sum\n",
    "\n",
    "\n",
    "play()[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_12007/3891364554.py:27: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(200,\n",
       " (array([-0.03761492, -0.02470117,  0.00951082,  0.02469873], dtype=float32),\n",
       "  1,\n",
       "  1.0,\n",
       "  array([-0.03810894,  0.17028311,  0.01000479, -0.2649683 ], dtype=float32),\n",
       "  False))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Replay buffer\n",
    "class Pool:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.pool = []\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.pool)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.pool[i]\n",
    "\n",
    "    # Refresh the buffer with newly played episodes\n",
    "    def update(self):\n",
    "        # Add at least 200 new transitions per update\n",
    "        old_len = len(self.pool)\n",
    "        while len(self.pool) - old_len < 200:\n",
    "            self.pool.extend(play()[0])\n",
    "\n",
    "        # Keep only the most recent 20,000 transitions\n",
    "        self.pool = self.pool[-20_000:]\n",
    "\n",
    "    # Sample a batch of transitions\n",
    "    def sample(self):\n",
    "        data = random.sample(self.pool, 64)\n",
    "\n",
    "        state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 4)\n",
    "        action = torch.LongTensor([i[1] for i in data]).reshape(-1, 1)\n",
    "        reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)\n",
    "        next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, 4)\n",
    "        over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1)\n",
    "\n",
    "        return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "pool = Pool()\n",
    "pool.update()\n",
    "state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "len(pool), pool[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_action = torch.optim.Adam(model_action.parameters(), lr=2e-4)\n",
    "optimizer_value1 = torch.optim.Adam(model_value1.parameters(), lr=2e-3)\n",
    "optimizer_value2 = torch.optim.Adam(model_value2.parameters(), lr=2e-3)\n",
    "\n",
    "\n",
    "def soft_update(_from, _to):\n",
    "    # Polyak averaging: blend 0.5% of the online weights into the target network\n",
    "    for param_from, param_to in zip(_from.parameters(), _to.parameters()):\n",
    "        value = param_to.data * 0.995 + param_from.data * 0.005\n",
    "        param_to.data.copy_(value)\n",
    "\n",
    "\n",
    "def get_prob_entropy(state):\n",
    "    prob = model_action(torch.FloatTensor(state).reshape(-1, 4))\n",
    "    entropy = prob * (prob + 1e-8).log()\n",
    "    entropy = -entropy.sum(dim=1, keepdim=True)\n",
    "\n",
    "    return prob, entropy\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)\n",
    "\n",
    "\n",
    "alpha = 1.0"
   ]
  },
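  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `soft_update` function above can be read as Polyak averaging of the target network weights. Written out, with $\\theta$ for the online weights and $\\bar{\\theta}$ for the target weights (notation assumed here for illustration):\n",
    "\n",
    "$$\\bar{\\theta} \\leftarrow 0.995\\,\\bar{\\theta} + 0.005\\,\\theta$$\n",
    "\n",
    "If the online weights stayed fixed, the target would retain $0.995^{200} \\approx 0.367$ of its old value after 200 updates, so the target networks lag the online critics by roughly a few hundred gradient steps. This is a rough reading of the constants, not a tuned recommendation."
   ]
  },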
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.5886831283569336"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_value(state, action, reward, next_state, over):\n",
    "    requires_grad(model_value1, True)\n",
    "    requires_grad(model_value2, True)\n",
    "    requires_grad(model_action, False)\n",
    "\n",
    "    # Compute the target\n",
    "    with torch.no_grad():\n",
    "        # Policy probabilities and entropy at the next state\n",
    "        prob, entropy = get_prob_entropy(next_state)\n",
    "        target1 = model_value1_next(next_state)\n",
    "        target2 = model_value2_next(next_state)\n",
    "        target = torch.min(target1, target2)\n",
    "\n",
    "    # Expected Q under the policy plus the entropy bonus (higher entropy is better)\n",
    "    target = (prob * target).sum(dim=1, keepdim=True)\n",
    "    target = target + alpha * entropy\n",
    "    target = target * 0.98 * (1 - over) + reward\n",
    "\n",
    "    # Compute the value estimates\n",
    "    value = model_value1(state).gather(dim=1, index=action)\n",
    "    loss = torch.nn.functional.mse_loss(value, target)\n",
    "    loss.backward()\n",
    "    optimizer_value1.step()\n",
    "    optimizer_value1.zero_grad()\n",
    "\n",
    "    value = model_value2(state).gather(dim=1, index=action)\n",
    "    loss = torch.nn.functional.mse_loss(value, target)\n",
    "    loss.backward()\n",
    "    optimizer_value2.step()\n",
    "    optimizer_value2.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_value(state, action, reward, next_state, over)"
   ]
  },
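  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The critic update above can be summarized by the following target. This is a sketch in notation assumed here: $s'$ is `next_state`, $d$ is `over`, $\\bar{Q}_1$ and $\\bar{Q}_2$ are `model_value1_next` and `model_value2_next`, $\\pi$ is `model_action`, and $H$ is the entropy returned by `get_prob_entropy`.\n",
    "\n",
    "$$y = r + 0.98\\,(1 - d)\\Big[\\sum_{a'} \\pi(a' \\mid s') \\min\\big(\\bar{Q}_1(s', a'), \\bar{Q}_2(s', a')\\big) + \\alpha H[\\pi(\\cdot \\mid s')]\\Big]$$\n",
    "\n",
    "Each critic $Q_i$ is then fit by minimizing the mean squared error $(Q_i(s, a) - y)^2$ over the batch. For example, with $r = 1$, $d = 0$ and a bracketed term of $10$, the target is $y = 1 + 0.98 \\times 10 = 10.8$; when the episode is over ($d = 1$), the target is just $y = r = 1$."
   ]
  },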
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.7090209126472473"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_action(state):\n",
    "    requires_grad(model_value1, False)\n",
    "    requires_grad(model_value2, False)\n",
    "    requires_grad(model_action, True)\n",
    "\n",
    "    # Compute the policy probabilities and entropy\n",
    "    prob, entropy = get_prob_entropy(state)\n",
    "\n",
    "    # Compute the value estimates\n",
    "    value1 = model_value1(state)\n",
    "    value2 = model_value2(state)\n",
    "    value = torch.min(value1, value2)\n",
    "\n",
    "    # Expected value under the policy\n",
    "    value = (prob * value).sum(dim=1, keepdim=True)\n",
    "\n",
    "    # Add the entropy bonus; negate because the optimizer minimizes\n",
    "    loss = -(value + alpha * entropy).mean()\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_action.step()\n",
    "    optimizer_action.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_action(state)"
   ]
  },
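  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The actor update above minimizes the following loss, written in notation assumed here: $N$ is the batch size (64 in `Pool.sample`), $\\pi$ is `model_action`, and $Q_1$, $Q_2$ are `model_value1` and `model_value2`.\n",
    "\n",
    "$$L_{\\pi} = -\\frac{1}{N} \\sum_{i=1}^{N} \\Big[\\sum_{a} \\pi(a \\mid s_i) \\min\\big(Q_1(s_i, a), Q_2(s_i, a)\\big) + \\alpha H[\\pi(\\cdot \\mid s_i)]\\Big]$$\n",
    "\n",
    "Because there are only two discrete actions, the expectation over actions is computed exactly by summing over both of them, so no action sampling is needed to obtain this gradient."
   ]
  },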
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 426 0.9 22.2\n",
      "10 2870 0.31381059609000017 163.05\n",
      "20 6006 0.10941898913151243 181.9\n",
      "30 9576 0.03815204244769462 176.25\n",
      "40 12623 0.013302794647291147 168.75\n",
      "50 15774 0.004638397686588107 162.0\n",
      "60 18964 0.0016173092699229901 182.6\n",
      "70 20000 0.0005639208733960181 192.65\n",
      "80 20000 0.00019662705047555326 199.2\n",
      "90 20000 6.85596132412799e-05 165.8\n",
      "100 20000 2.390525899882879e-05 199.1\n",
      "110 20000 8.335248417898115e-06 148.0\n",
      "120 20000 2.9063214161987086e-06 159.8\n",
      "130 20000 1.0133716178293888e-06 163.1\n",
      "140 20000 3.5334083494636473e-07 169.85\n",
      "150 20000 1.2320233115273002e-07 177.8\n",
      "160 20000 4.295799664301754e-08 168.25\n",
      "170 20000 1.4978527259308396e-08 199.45\n",
      "180 20000 5.222689519770981e-09 199.05\n",
      "190 20000 1.821039234880364e-09 200.0\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    global alpha\n",
    "    model_action.train()\n",
    "    model_value1.train()\n",
    "    model_value2.train()\n",
    "\n",
    "    # Train for 200 epochs\n",
    "    for epoch in range(200):\n",
    "        # Collect at least 200 new transitions\n",
    "        pool.update()\n",
    "\n",
    "        # After each data refresh, run 200 gradient updates\n",
    "        for i in range(200):\n",
    "            # Sample a batch of transitions\n",
    "            state, action, reward, next_state, over = pool.sample()\n",
    "\n",
    "            # Update the critics, the actor, then the target networks\n",
    "            train_value(state, action, reward, next_state, over)\n",
    "            train_action(state)\n",
    "            soft_update(model_value1, model_value1_next)\n",
    "            soft_update(model_value2, model_value2_next)\n",
    "\n",
    "        alpha *= 0.9\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, len(pool), alpha, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
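  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The alpha column in the log above follows a geometric schedule. Since `alpha` starts at 1.0 and is multiplied by 0.9 at the end of every epoch, before the log line is printed, the value printed at epoch $k$ is\n",
    "\n",
    "$$\\alpha_k = 0.9^{\\,k+1}$$\n",
    "\n",
    "For example, $\\alpha_0 = 0.9$, $\\alpha_{10} = 0.9^{11} \\approx 0.314$ and $\\alpha_{20} = 0.9^{21} \\approx 0.109$, matching the printed values. Since the entropy of a two-action policy is at most $\\ln 2 \\approx 0.693$, the entropy bonus at epoch 20 is at most about $0.109 \\times 0.693 \\approx 0.076$, and it becomes negligible in the later epochs."
   ]
  },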
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAT4klEQVR4nO3dX0xb5/0G8McGbALkmBKGXQRuMq1ayi//VpLAaS82tV5ohqJm4WKboo5VUbJmJiqlilakJl27TkTZRbduKbnZkl4sy8SktCpL2jHSEk1xQkOLRCBhq5QKmsR2E4YNJNiAv7+LiNM4ISkmhNfHPB/pSJzzvra/5zV+OOe8trGIiICISAGr6gKIaP5iABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTLKAmjfvn1YvHgxMjMzUVZWhvb2dlWlEJEiSgLob3/7G+rq6vDKK6/gk08+wcqVK1FRUYFgMKiiHCJSxKLiw6hlZWVYs2YN/vjHPwIAYrEYiouLsWPHDrz00ktzXQ4RKZI+1w8YjUbR0dGB+vp6Y5vVaoXH44HP55vyNpFIBJFIxFiPxWIYGBjAokWLYLFY7nvNRJQYEcHQ0BAKCwthtd75RGvOA+jKlSuYmJiA0+mM2+50OnH+/Pkpb9PQ0IBXX311LsojolnU39+PoqKiO7bPeQDNRH19Perq6oz1UCgEt9uN/v5+aJqmsDIimko4HEZxcTEWLlx4135zHkD5+flIS0tDIBCI2x4IBOByuaa8jd1uh91uv227pmkMIKIk9nWXSOZ8Fsxms6G0tBStra3GtlgshtbWVui6PtflEJFCSk7B6urqUF1djdWrV2Pt2rX43e9+h5GRETz77LMqyiEiRZQE0I9+9CN8+eWX2L17N/x+P1atWoX333//tgvTRJTalLwP6F6Fw2E4HA6EQiFeAyJKQtN9jfKzYESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlIm4QA6ceIENmzYgMLCQlgsFrzzzjtx7SKC3bt348EHH8SCBQvg8Xjw3//+N67PwMAANm/eDE3TkJubiy1btmB4ePiedoSIzCfhABoZGcHKlSuxb9++Kdv37t2LN998E/v378fp06eRnZ2NiooKjI6OGn02b96M7u5utLS0oLm5GSdOnMC2bdtmvhdEZE5yDwDIkSNHjPVYLCYul0t++9vfGtsGBwfFbrfLX//6VxER6enpEQDy8ccfG32OHTsmFotFLl68OK3HDYVCAkBCodC9lE9E98l0X6Ozeg3owoUL8Pv98Hg8xjaHw4GysjL4fD4AgM/nQ25uLlavXm308Xg8sFqtOH369JT3G4lEEA6H4xYiMr9ZDSC/3w8AcDqdcdudTqfR5vf7UVBQENeenp6OvLw8o8+tGhoa4HA4jKW4uHg2yyYiRUwxC1ZfX49QKGQs/f39qksiolkwqwHkcrkAAIFAIG57IBAw2lwuF4LBYFz7+Pg4BgYGjD63stvt0DQtbiEi85vVAFqyZAlcLhdaW1uNbeFwGKdPn4au6wAAXdcxODiIjo4Oo8/x48cRi8VQVlY2m+UQUZJLT/QGw8PD+Oyzz4z1CxcuoLOzE3l5eXC73aitrcXrr7+Ohx9+GEuWLMGuXbtQWFiIjRs3AgAeeeQRPPXUU9i6dSv279+PsbEx1NTU4Mc//jEKCwtnbceIyAQSnV778MMPBcBtS3V1tYjcmIrftWuXOJ1Osdvt8uSTT0pvb2/cfVy9elV+8pOfSE5OjmiaJs8++6wMDQ3N+hQfEakx3deoRUREYf7NSDgchsPhQCgU4vUgoiQ03deoKWbBiCg1MYCI
SBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRL+tzxEyUhEEP6iG+OjI8a2rPxiLHiA/+opmTGAKCWIxPBF+zu4dqXP2FZUtokBlOR4CkapQWKA+f7D1LzHAKLUIAIBA8hsGECUEkSER0AmxACi1CACkZjqKihBDCBKCcJrQKbEAKIUwVMwM0oogBoaGrBmzRosXLgQBQUF2LhxI3p7e+P6jI6Owuv1YtGiRcjJyUFVVRUCgUBcn76+PlRWViIrKwsFBQXYuXMnxsfH731vaN4SkRvXgchUEgqgtrY2eL1enDp1Ci0tLRgbG8O6deswMvLVm79eeOEFvPfee2hqakJbWxsuXbqETZs2Ge0TExOorKxENBrFyZMn8fbbb+PgwYPYvXv37O0VzT8iAGfBzEfuQTAYFADS1tYmIiKDg4OSkZEhTU1NRp9z584JAPH5fCIicvToUbFareL3+40+jY2NommaRCKRaT1uKBQSABIKhe6lfEoh0ZFB6fzLS9K+f6uxXPr0mOqy5q3pvkbv6RpQKBQCAOTl5QEAOjo6MDY2Bo/HY/RZunQp3G43fD4fAMDn82H58uVwOp1Gn4qKCoTDYXR3d0/5OJFIBOFwOG4huplwGt6UZhxAsVgMtbW1ePzxx7Fs2TIAgN/vh81mQ25ublxfp9MJv99v9Lk5fCbbJ9um0tDQAIfDYSzFxcUzLZtSlEgMEoufhrdYOMeS7Gb8DHm9Xpw9exaHDx+ezXqmVF9fj1AoZCz9/f33/THJXMavD2F8dMhYt1jTYHc473ILSgYz+jBqTU0NmpubceLECRQVFRnbXS4XotEoBgcH446CAoEAXC6X0ae9vT3u/iZnySb73Mput8Nut8+kVJonRGK3vBHRAmt6hrJ6aHoSOgISEdTU1ODIkSM4fvw4lixZEtdeWlqKjIwMtLa2Gtt6e3vR19cHXdcBALquo6urC8Fg0OjT0tICTdNQUlJyL/tC9BULT8HMIKEjIK/Xi0OHDuHdd9/FwoULjWs2DocDCxYsgMPhwJYtW1BXV4e8vDxomoYdO3ZA13WUl5cDANatW4eSkhI888wz2Lt3L/x+P15++WV4vV4e5dAssjCATCChAGpsbAQAfO9734vbfuDAAfzsZz8DALzxxhuwWq2oqqpCJBJBRUUF3nrrLaNvWloampubsX37dui6juzsbFRXV+O11167tz0hupWVAZTsLCLmm7sMh8NwOBwIhULQNE11OZQEhoMXcO6dPcZUvCUtA0s3vIgc5zcVVzY/Tfc1yj8RlJIsuDETRsmNAUSpyWIBeA0o6fEZopRl4TWgpMdniFIWZ8GSH58hSlEWHgGZAJ8hSk03rkKrroK+Bp8hSg1TvJvEYrEoKIQSwQCilHDrJ+HJHBhAlBJEJlSXQDPAAKKUwCMgc2IAUWrgEZApMYAoJUgsxu+kNyEGEKUEnoKZEwOIUgIvQpsTA4hSAo+AzIkBRCmBR0DmxACi1MAjIFNiAFFKiE2M4+ZpMMuND4Mpq4emhwFEKWH0f5fj1jNy8pBmy1RUDU0XA4hSQiw2HrduTUvn9wGZAJ8hSk0WK3gKlvwYQJSSLBbLje+FpqTGAKKUxNMvc+CzRKnJYuUXkplAQv8ZlUiVaDSKa9eu3bk9Eolbn5iYQCgUhiVt6l/xzMxMZGZylkw1BhCZwj/+8Q/s2LHjju3b1v8fnip9yFg/dbodr9fuw3hs6o/I79q1Cz//+c9nvU5KDAOITOHatWu4ePHiHduHh7+Jz0eXIxh9CHkZlzB8/Si+uHgRE3cIoKGhoftVKiUgoWtAjY2NWLFiBTRNg6Zp0HUdx44dM9pHR0fh9XqxaNEi5OTkoKqqCoFAIO4++vr6UFlZiaysLBQUFGDnzp0YHx+/9aGIEtI3WoLekTIMjBXis2uP4j/Dq6b6
nnpKMgkFUFFREfbs2YOOjg6cOXMGTzzxBJ5++ml0d3cDAF544QW89957aGpqQltbGy5duoRNmzYZt5+YmEBlZSWi0ShOnjyJt99+GwcPHsTu3btnd69o3hmeyIUYv85WhMcdEH5DWdJL6BRsw4YNceu/+c1v0NjYiFOnTqGoqAh/+tOfcOjQITzxxBMAgAMHDuCRRx7BqVOnUF5ejn/+85/o6enBv/71LzidTqxatQq//vWv8ctf/hK/+tWvYLPZZm/PaF5x2S/AZrmGqCxAuiUKV8ZnPAIygRlfA5qYmEBTUxNGRkag6zo6OjowNjYGj8dj9Fm6dCncbjd8Ph/Ky8vh8/mwfPlyOJ1Oo09FRQW2b9+O7u5ufOc730mohvPnzyMnJ2emu0AmcunSpbu2n+05iav/u4qBsQfhSP8SQwP/uWv/YDCInp6e2SyRbjI8PDytfgkHUFdXF3Rdx+joKHJycnDkyBGUlJSgs7MTNpsNubm5cf2dTif8fj8AwO/3x4XPZPtk251EIhFEbppmDYfDAIBQKMTrR/PE3abgAeB0zxdAzxfTvr/r169jcHDwHquiOxkZGZlWv4QD6Nvf/jY6OzsRCoXw97//HdXV1Whra0u4wEQ0NDTg1VdfvW17WVkZNE27r49NyeHChQuzen8PPfQQHnvssVm9T/rK5EHC10n4ndA2mw3f+ta3UFpaioaGBqxcuRK///3v4XK5EI1Gb/urEggE4HK5AAAul+u2WbHJ9ck+U6mvr0coFDKW/v7+RMsmoiR0zx/FiMViiEQiKC0tRUZGBlpbW4223t5e9PX1Qdd1AICu6+jq6kIwGDT6tLS0QNM0lJSU3PEx7Ha7MfU/uRCR+SV0ClZfX4/169fD7XZjaGgIhw4dwkcffYQPPvgADocDW7ZsQV1dHfLy8qBpGnbs2AFd11FeXg4AWLduHUpKSvDMM89g79698Pv9ePnll+H1emG32+/LDhJR8koogILBIH7605/i8uXLcDgcWLFiBT744AN8//vfBwC88cYbsFqtqKqqQiQSQUVFBd566y3j9mlpaWhubsb27duh6zqys7NRXV2N1157bXb3ilKOzWab1SNf/sFLDhYR871bIhwOw+FwIBQK8XRsnhgZGcHAwMCs3Z/D4eDvzn003dcoPwtGppCdnY3s7GzVZdAs4/cBEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImXSVRcwEyICAAiHw4orIaKpTL42J1+rd2LKALp69SoAoLi4WHElRHQ3Q0NDcDgcd2w3ZQDl5eUBAPr6+u66cxQvHA6juLgY/f390DRNdTmmwDGbGRHB0NAQCgsL79rPlAFktd64dOVwOPhLMQOapnHcEsQxS9x0Dg54EZqIlGEAEZEypgwgu92OV155BXa7XXUppsJxSxzH7P6yyNfNkxER3SemPAIiotTAACIiZRhARKQMA4iIlDFlAO3btw+LFy9GZmYmysrK0N7errokZRoaGrBmzRosXLgQBQUF2LhxI3p7e+P6jI6Owuv1YtGiRcjJyUFVVRUCgUBcn76+PlRWViIrKwsFBQXYuXMnxsfH53JXlNmzZw8sFgtqa2uNbRyzOSImc/jwYbHZbPLnP/9Zuru7ZevWrZKbmyuBQEB1aUpUVFTIgQMH5OzZs9LZ2Sk/+MEPxO12y/DwsNHnueeek+LiYmltbZUzZ85IeXm5PPbYY0b7+Pi4LFu2TDwej3z66ady9OhRyc/Pl/r6ehW7NKfa29tl8eLFsmLFCnn++eeN7RyzuWG6AFq7dq14vV5jfWJiQgoLC6WhoUFhVckjGAwKAGlraxMRkcHBQcnIyJCmpiajz7lz5wSA
+Hw+ERE5evSoWK1W8fv9Rp/GxkbRNE0ikcjc7sAcGhoakocfflhaWlrku9/9rhFAHLO5Y6pTsGg0io6ODng8HmOb1WqFx+OBz+dTWFnyCIVCAL76wG5HRwfGxsbixmzp0qVwu93GmPl8PixfvhxOp9PoU1FRgXA4jO7u7jmsfm55vV5UVlbGjQ3AMZtLpvow6pUrVzAxMRH3pAOA0+nE+fPnFVWVPGKxGGpra/H4449j2bJlAAC/3w+bzYbc3Ny4vk6nE36/3+gz1ZhOtqWiw4cP45NPPsHHH398WxvHbO6YKoDo7rxeL86ePYt///vfqktJav39/Xj++efR0tKCzMxM1eXMa6Y6BcvPz0daWtptsxGBQAAul0tRVcmhpqYGzc3N+PDDD1FUVGRsd7lciEajGBwcjOt/85i5XK4px3SyLdV0dHQgGAzi0UcfRXp6OtLT09HW1oY333wT6enpcDqdHLM5YqoAstlsKC0tRWtrq7EtFouhtbUVuq4rrEwdEUFNTQ2OHDmC48ePY8mSJXHtpaWlyMjIiBuz3t5e9PX1GWOm6zq6uroQDAaNPi0tLdA0DSUlJXOzI3PoySefRFdXFzo7O41l9erV2Lx5s/Ezx2yOqL4KnqjDhw+L3W6XgwcPSk9Pj2zbtk1yc3PjZiPmk+3bt4vD4ZCPPvpILl++bCzXrl0z+jz33HPidrvl+PHjcubMGdF1XXRdN9onp5TXrVsnnZ2d8v7778s3vvGNeTWlfPMsmAjHbK6YLoBERP7whz+I2+0Wm80ma9eulVOnTqkuSRkAUy4HDhww+ly/fl1+8YtfyAMPPCBZWVnywx/+UC5fvhx3P59//rmsX79eFixYIPn5+fLiiy/K2NjYHO+NOrcGEMdsbvDrOIhIGVNdAyKi1MIAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlPl/1HCkNRaam3EAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:cuda117]",
   "language": "python",
   "name": "conda-env-cuda117-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
